Fingerprint GAN¶

Introduction¶

An open problem in latent fingerprint recognition is the enhancement of poor-quality fingerprints to improve fingerprint matching accuracy. Although many algorithms exist to enhance fingerprint images, the results of even some of the best solutions are less than satisfactory. For this reason, my research explored the use of Generative Adversarial Networks (GANs) to enhance latent fingerprint images and thereby improve matching accuracy.

This GAN was adapted from the PyTorch DCGAN tutorial (https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html) and is similar to a traditional GAN.

A common problem with GANs is mode collapse, which can occur when the network is trained on multi-modal data, i.e., data that contains more than one class. This GAN suffers from the same problem: the model appears to start producing realistic fingerprint images, but then mode collapse occurs and nothing but a grey image is produced.

Some common ways to deal with mode collapse are the following:

  • Grouping the classes
  • Anticipating Counter-actions
  • Learning from Experience
  • Multiple Networks

This experiment attempts to address the mode collapse issue by grouping the classes in the dataset. The network architectures of the discriminator and the generator are similar to that of a Siamese Neural Network, and the Siamese Neural Network is powerful enough to determine which pairs of fingerprints match and which do not. Perhaps the GAN's discriminator is likewise powerful enough to separate these fingerprints into different categories. To test this, the dataset is grouped by unique fingerprint. There are 10,000 unique fingerprints in this dataset, and each unique fingerprint has at least 10 impressions. To achieve this grouping, the dataloader is modified so that each batch contains impressions of only one unique fingerprint.
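The per-fingerprint batching described above can be sketched as a standalone helper (a minimal illustration only; the notebook's actual dataloader instead tracks a batch index inside `FingerprintLoader.__getitem__`, and the function and variable names here are hypothetical):

```python
import random
from collections import defaultdict

def group_batches(samples, batch_size, seed=None):
    """Group (fingerprint_id, impression) samples so that every batch
    contains impressions of exactly one unique fingerprint."""
    rng = random.Random(seed)
    by_finger = defaultdict(list)
    for finger_id, impression in samples:
        by_finger[finger_id].append((finger_id, impression))
    batches = []
    for finger_id, group in by_finger.items():
        rng.shuffle(group)
        # split this fingerprint's impressions into whole batches
        for i in range(0, len(group), batch_size):
            batches.append(group[i:i + batch_size])
    rng.shuffle(batches)  # shuffle batch order, not batch contents
    return batches

# Example: 3 fingerprints x 10 impressions each, batches of 5
samples = [(f, i) for f in range(3) for i in range(10)]
batches = group_batches(samples, batch_size=5, seed=0)
```

Each returned batch is "pure": it contains only one class, which is the grouping behavior the modified dataloader aims for.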

In addition, a couple more changes were made. The network architecture of the generator and discriminator was changed to that in the 'Fingerprint GAN-Siamese NN Discriminator' notebook. As a result, the generator has 3 convolutional layers, followed by 2 deconvolutional (transposed convolution) layers, followed by a final convolutional layer; a Hyperbolic Tangent (Tanh) activation is then applied to produce the output. The discriminator has 5 convolutional layers, a fully connected layer, and a sigmoid function to produce the final output.

Finally, the learning rate of the discriminator was set lower than that of the generator (see dlr and glr in the Config section below).

Install Dependencies¶

In [1]:
!pip install Augmentor
!pip install pillow
!pip install seaborn
%reload_ext autoreload
%autoreload
Requirement already satisfied: Augmentor in /usr/local/lib/python3.9/dist-packages (0.2.10)
Requirement already satisfied: future>=0.16.0 in /usr/lib/python3/dist-packages (from Augmentor) (0.18.2)
Requirement already satisfied: Pillow>=5.2.0 in /usr/local/lib/python3.9/dist-packages (from Augmentor) (9.2.0)
Requirement already satisfied: numpy>=1.11.0 in /usr/local/lib/python3.9/dist-packages (from Augmentor) (1.23.1)
Requirement already satisfied: tqdm>=4.9.0 in /usr/local/lib/python3.9/dist-packages (from Augmentor) (4.64.0)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
Requirement already satisfied: pillow in /usr/local/lib/python3.9/dist-packages (9.2.0)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
Requirement already satisfied: seaborn in /usr/local/lib/python3.9/dist-packages (0.11.2)
Requirement already satisfied: matplotlib>=2.2 in /usr/local/lib/python3.9/dist-packages (from seaborn) (3.5.2)
Requirement already satisfied: numpy>=1.15 in /usr/local/lib/python3.9/dist-packages (from seaborn) (1.23.1)
Requirement already satisfied: pandas>=0.23 in /usr/local/lib/python3.9/dist-packages (from seaborn) (1.4.3)
Requirement already satisfied: scipy>=1.0 in /usr/local/lib/python3.9/dist-packages (from seaborn) (1.8.1)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib>=2.2->seaborn) (21.3)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.9/dist-packages (from matplotlib>=2.2->seaborn) (0.11.0)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.9/dist-packages (from matplotlib>=2.2->seaborn) (1.4.3)
Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib>=2.2->seaborn) (4.34.4)
Requirement already satisfied: pyparsing>=2.2.1 in /usr/local/lib/python3.9/dist-packages (from matplotlib>=2.2->seaborn) (3.0.9)
Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.9/dist-packages (from matplotlib>=2.2->seaborn) (9.2.0)
Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.9/dist-packages (from matplotlib>=2.2->seaborn) (2.8.2)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/dist-packages (from pandas>=0.23->seaborn) (2022.1)
Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.7->matplotlib>=2.2->seaborn) (1.14.0)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

import torchvision.utils as vutils
import torchvision.datasets as dset
import torchvision.transforms.functional as Fv

import os
import time
import math
import random
import Augmentor
import numpy as np
from numpy import unravel_index 
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from random import Random
from skimage.util import random_noise

from IPython.display import HTML
import matplotlib.animation as animation

Config¶

In [43]:
target_dir   = "../../storage/Prepped_Fingerprints_206x300/Bad/"
template_dir = "../../storage/Prepped_Fingerprints_206x300/Enhanced_Good/"

model_results_file = "checkpoint/GAN_results.pt"
model_ckpt_file = "checkpoint/GAN_checkpoint.pt"

im_size = (300, 300)
var_max = 0.5

num_train = 10000
num_valid = 1000
batch_size = 64
num_workers = 1
shuffle = True
augment = True

# Number of training epochs
num_epochs = 50
start_epoch = 1

# Number of channels in the training images. For color images this is 3
nc = 1
# Size of z latent vector (i.e. size of generator input)
nz = 100
# Size of feature maps in generator
ngf = 64
# Size of feature maps in discriminator
ndf = 64

# Learning rate for optimizers
dlr = 0.00002
glr = 0.0002
# Beta1 hyperparam for Adam optimizers
beta1 = 0.5
# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")

Utils¶

In [44]:
def plotFigures(netG, test_batch):
    # Note: appends to the global img_list defined in the training cell
    with torch.no_grad():
        fake_disp = netG(test_batch[0]).detach().cpu()
        img_list.append(vutils.make_grid(fake_disp, padding=2, normalize=True))
    plt.figure(figsize=(15,15))
    plt.subplot(1,1,1)
    plt.axis("off")
    plt.title("Enhanced Images")
    plt.imshow(np.transpose(img_list[-1],(1,2,0)))
    plt.show()
    return img_list

    
def saveCkpt(filepath, epoch, netG, netD, optimizerG, optimizerD, G_losses, D_losses, img_list, test_batch, iters):
    if os.path.isfile(filepath):
        os.remove(filepath)
    torch.save({
        'epoch' : epoch,
        'netG_state_dict' : netG.state_dict(),
        'netD_state_dict' : netD.state_dict(),
        'optimizerG_state_dict' : optimizerG.state_dict(),
        'optimizerD_state_dict' : optimizerD.state_dict(),
        'G_losses' : G_losses,
        'D_losses' : D_losses,
        'img_list' : img_list,
        'test_batch' : test_batch,
        'iters' : iters,
    }, filepath)


def showImages(batch, labels=None):
    """
    Displays a set of batch images
    
    :param batch: A batch of image pairs and labels to display
    :param labels: The labels for the images
    """
    plt.figure(figsize=(20,6))
    plt.subplot(2,1,1)
    plt.axis("off")
    plt.imshow(np.transpose(vutils.make_grid(batch[0].to(device)[:8], padding=5, normalize=True).cpu(),(1,2,0)))

    plt.subplot(2,1,2)
    plt.axis("off")
    plt.imshow(np.transpose(vutils.make_grid(batch[1].to(device)[:8], padding=5, normalize=True).cpu(),(1,2,0)))
    plt.show()
    
    if labels is not None:
        for l in labels:
            if l == 1:
                print("  same  ", end=" ")
            else:
                print("  diff  ", end=" ")
        
        
def validate(epoch):
    # switch to evaluate mode
    netD.eval()

    correct = 0
    total = 0
    for i, (val_Im1, val_Im2, val_y) in enumerate(valid_loader):
        with torch.no_grad():
            variation = random.uniform(0, var_max)
            val_Im1 = torch.tensor(random_noise(val_Im1, mode='gaussian', mean=0, var=variation, clip=True), dtype=torch.float32)
            val_Im1, val_Im2, val_y = val_Im1.to(device), val_Im2.to(device), val_y.to(device)

            # round the sigmoid output to a 0/1 prediction
            pred = torch.round(netD(val_Im1, val_Im2))
            correct += (pred == val_y).sum().item()
            total += val_Im1.shape[0]
            if total > num_valid:
                break

    # compute accuracy over all validated pairs
    valid_acc = (100. * correct) / total
    return valid_acc
    
                
class AverageMeter(object):
    """
    Computes and stores the average and
    current value.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        
class MovingAvg(object):
    """
    Computes the moving average of values
    """
    def __init__(self, length=10):
        self.length = length
        self.movingAvg = np.array([], dtype='f')
        
    def average(self):
        return np.average(self.movingAvg)
        
    def pop(self):
        if len(self.movingAvg) > 0:
            self.movingAvg = np.delete(self.movingAvg, 0, axis=0)
    
    def push(self, val):
        self.movingAvg = np.append(self.movingAvg, [val])
        if len(self.movingAvg) > self.length:
            self.movingAvg = np.delete(self.movingAvg, 0, axis = 0)
            

Data Loader¶

In [45]:
def get_train_loader(target_dir, template_dir,
                           batch_size,
                           num_train,
                           num_valid,
                           shuffle=False,
                           num_workers=2,
                           pin_memory=False):
    """
    Utility function for loading and returning train 
    iterator over the dataset.
    If using CUDA, num_workers should be set to `1` and pin_memory to `True`.
    Args
    ----
    - target_dir: path directory to the target dataset.
    - template_dir: path directory to the template dataset.
    - batch_size: how many samples per batch to load.
    - augment: whether to load the augmented version of the train dataset.
    - num_workers: number of subprocesses to use when loading the dataset. Set
      to `1` if using GPU.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      `True` if using GPU.
    """
    fingerprints = [str(finger) for finger in range(1,10000+1)]
    random.shuffle(fingerprints)
    training_prints = fingerprints[:10000]
    
    # Get the Training Dataloader
    train_dataset = FingerprintLoader(target_dir, template_dir, num_train, training_prints, batch_size)
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=shuffle,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    return (train_loader)



def get_train_valid_loader(target_dir, template_dir,
                           batch_size,
                           num_train,
                           num_valid,
                           shuffle=False,
                           num_workers=2,
                           pin_memory=False):
    """
    Utility function for loading and returning train and valid 
    iterators over the dataset.
    If using CUDA, num_workers should be set to `1` and pin_memory to `True`.
    Args
    ----
    - target_dir: path directory to the target dataset.
    - template_dir: path directory to the template dataset.
    - batch_size: how many samples per batch to load.
    - augment: whether to load the augmented version of the train dataset.
    - num_workers: number of subprocesses to use when loading the dataset. Set
      to `1` if using GPU.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      `True` if using GPU.
    """
    fingerprints = [str(finger) for finger in range(1,10000+1)]
    random.shuffle(fingerprints)
    training_prints = fingerprints[:7500]
    validation_prints = fingerprints[7500:]
    
    # Get the Training Dataloader
    train_dataset = FingerprintLoader(target_dir, template_dir, num_train, training_prints, batch_size)
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=shuffle,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    # Get the Validation Dataloader 
    valid_dataset = FingerprintLoader(target_dir, template_dir, num_valid, validation_prints, batch_size)
    valid_loader = DataLoader(
        valid_dataset, batch_size=batch_size, shuffle=shuffle,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    return (train_loader, valid_loader)
    

class FingerprintLoader(Dataset):
    """
    This class is used to help load the fingerprint dataset.
    """
    def __init__(self, target_dataset, template_dataset, num_train, dataset, batch_size):
        """
        Initializes an instance for the FingerprintLoader class.

        :param self: instance of the FingerprintLoader class
        :param template_dataset: The template fingerprint dataset
        :param target_dataset: The second fingerprint dataset to match against 
                               the template dataset
        :param num_train: The number of images to load
        :param dataset: List of fingerprints to include in the set
        """
        super(FingerprintLoader, self).__init__()
        self.target_dataset = target_dataset
        self.template_dataset = template_dataset
        self.fingerprints_dataset = dataset
        self.num_train = num_train
        self.augment = augment
        self.batch_size = batch_size
        self.batchInd = batch_size
        self.target_im_name = ""

    def __len__(self):
        """
        Helper function to return the length of the dataset

        :param self: instance of the FingerprintLoader class
        :return: the length of the dataset as an int
        """
        return self.num_train

    def __getitem__(self, index):
        """
        Getter function for accessing images from the dataset. This function will choose a 
        fingerprint image from the dataset and its corresponding enhanced fingerprint image.
        It will then preprocess the images before returning them.
        :param self: instance of the FingerprintLoader class
        :param index: index for data image in set to return
        :return: Image from dataset as a tensor
        """
        target_im_filepath, enhanced_target_im_filepath = self.chooseTargetAndTemplateImages()
        targ_im = self.preprocessImage(target_im_filepath)
        enhanced_targ_im = self.preprocessImage(enhanced_target_im_filepath)
        return targ_im, enhanced_targ_im
    
    
    def chooseTargetAndTemplateImages(self):
        """
        Returns the filepath of the target fingerprint image and the enhanced template fingerprint.
        :param self: instance of the FingerprintLoader class
        :return: The filepaths of the target image and its enhanced template image
        """
        target_im_filepath = "targetim.jpg" 
        enhanced_target_im_filepath = "targetim.jpg"      
        # Choose an image pair that exists on disk
        while not os.path.isfile(target_im_filepath) or not os.path.isfile(enhanced_target_im_filepath):
            target_im_filepath = self.target_dataset + random.choice(os.listdir(self.target_dataset))
            target_im_filepath +=  "/Impression_1/"
            if self.batchInd % self.batch_size == 0:
                self.target_im_name = random.choice(self.fingerprints_dataset)
            target_im_filepath = target_im_filepath + self.target_im_name + '.jpg'
            enhanced_target_im_filepath = self.template_dataset + random.choice(os.listdir(self.template_dataset)) \
                + "/Impression_1/" + self.target_im_name + '.jpg' 
        self.batchInd += 1
        return target_im_filepath, enhanced_target_im_filepath
    
    
    def preprocessImage(self, im_filepath):
        """
        Preprocesses the image. This function will open the image, convert 
        it to grayscale, pad the image in order to make it square, 
        normalize the image, and then finally convert it to a tensor.
        :param im_filepath: Filepath of the image to preprocess
        :return: The preprocessed image
        """
        im = Image.open(im_filepath)
        # Convert to Grayscale
        im = im.convert('L')
        # Pad template image
        w, h = im.size
        dim = max(w, h)
        left = int((dim - w) / 2.0)
        top = int((dim - h) / 2.0)
        image = Image.new(im.mode, (dim, dim), 255)
        image.paste(im, (left, top))
        # apply transformation 
        trans = transforms.Compose([#p.torch_transform(),
                                transforms.Resize(im_size),
                                #transforms.CenterCrop(im_size),
                                transforms.Grayscale(1),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, ), (0.5, )),
                              ])
        # Apply the transformations to the images and labels
        preprocessedImage = trans(image)
        return preprocessedImage

Data Visualization for GAN¶

It was difficult to get access to sufficiently large fingerprint datasets for training, because fingerprints are considered personal information and such data is not commonly available. Because of this, I ended up synthetically generating my own dataset using Anguli. The generated dataset contains close to one million fingerprint images of varying quality, covering 10,000 unique fingerprints.

Below are some sample images that the GAN will be trained on. The top row shows the input fingerprint images that the generator is to enhance. The bottom row shows the enhanced versions of the fingerprints. Please note that this bottom row just represents the ideal enhanced versions of the fingerprints from the top row: these images were produced by applying Gabor filters to the good impressions of the fingerprints (and not necessarily the impressions shown in the top row).

The task of the generator is to enhance the images in the top row to look like Gabor-enhanced fingerprint images while also preserving the underlying ridge structure.
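As a rough illustration of the kind of filtering used to produce the bottom row, a real-valued Gabor kernel can be built with NumPy (a generic sketch of Gabor filtering, not the actual enhancement pipeline used to create this dataset; the orientation, frequency, and sigma values below are arbitrary examples):

```python
import numpy as np

def gabor_kernel(size, theta, freq, sigma):
    """Real-valued Gabor kernel: a Gaussian envelope modulating a cosine
    wave oriented at angle `theta` with spatial frequency `freq`."""
    half = size // 2
    y, x = np.mgrid[-half:half + 1, -half:half + 1]
    # rotate coordinates so the wave runs along direction `theta`
    x_theta = x * np.cos(theta) + y * np.sin(theta)
    envelope = np.exp(-(x**2 + y**2) / (2.0 * sigma**2))
    return envelope * np.cos(2.0 * np.pi * freq * x_theta)

# Example kernel tuned to a 45-degree ridge orientation
kernel = gabor_kernel(size=15, theta=np.pi / 4, freq=0.1, sigma=4.0)
```

Fingerprint enhancement typically convolves the image with a bank of such kernels, selecting the orientation and frequency locally from the estimated ridge flow.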

In [46]:
# Create the dataloader
#disp_dataset = dset.ImageFolder(root=data_dir)
disp_fingerprints = [str(finger) for finger in range(1,1000+1)]
disp_dataset = FingerprintLoader(target_dir, template_dir, num_train, disp_fingerprints, batch_size)
disp_dataloader = torch.utils.data.DataLoader(disp_dataset, batch_size=10,
                                         shuffle=shuffle, num_workers=1)

# Get a Batch of Sample Images
real_batch = next(iter(disp_dataloader))
batch = real_batch

# Display the Sample Images
plt.figure(figsize=(20,6))
plt.subplot(2,1,1)
plt.axis("off")
plt.imshow(np.transpose(vutils.make_grid(batch[0].to(device)[:8], padding=5, normalize=True).cpu(),(1,2,0)))

plt.subplot(2,1,2)
plt.axis("off")
plt.imshow(np.transpose(vutils.make_grid(batch[1].to(device)[:8], padding=5, normalize=True).cpu(),(1,2,0)))
plt.show()

Weight Initialization¶

In [47]:
# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
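As a quick sanity check of this DCGAN-style initialization, its statistics can be verified outside PyTorch (a NumPy-only sketch; the layer shape `(64, 1, 4, 4)` is just an example matching the first discriminator convolution):

```python
import numpy as np

rng = np.random.default_rng(0)
# DCGAN-style init: conv weights drawn from N(mean=0, std=0.02),
# mirroring the nn.init.normal_(m.weight.data, 0.0, 0.02) call above
w = rng.normal(loc=0.0, scale=0.02, size=(64, 1, 4, 4))
sample_mean = w.mean()
sample_std = w.std()
```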

Generator¶

In [48]:
class ResNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        """
        Args:
          in_channels (int):  Number of input channels.
          out_channels (int): Number of output channels.
          stride (int):       Controls the stride (currently unused).
        """
        super(ResNetBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, stride=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)


    def forward(self, x):
        # Note: the same conv/bn pair is applied twice, so both
        # convolutions share weights and batch-norm statistics
        out = torch.relu(self.bn(self.conv(x)))
        out = self.bn(self.conv(out))
        out += x
        out = F.relu(out)
        return out
    

# Generator Code
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        # Device
        self.ngpu = ngpu
        # Convolutional Layers
        self.conv1   = nn.Conv2d(nc, ngf, 7, 1, 3, bias=False)
        self.conv2   = nn.Conv2d(ngf, ngf * 2, 4, 2, 1, bias=False)
        self.conv3   = nn.Conv2d(ngf * 2, ngf * 4, 4, 2, 1, bias=False)
        self.conv4   = nn.Conv2d(ngf, nc, 7, 1, 3, bias=False)
        # Transpose Convolutional Layers
        self.deconv1 = nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False)
        self.deconv2 = nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False)
        # ResNet Layer
        self.resNet = ResNetBlock(ngf * 4, ngf * 4)
        # Batch Normalization Layers
        self.bn0 = nn.BatchNorm2d(ngf)
        self.bn1 = nn.BatchNorm2d(ngf * 2)
        self.bn2 = nn.BatchNorm2d(ngf * 4)
        self.bn3 = nn.BatchNorm2d(ngf * 8)


    def forward(self, x):
        out = F.relu(self.bn0(self.conv1(x)), True)
        out = F.relu(self.bn1(self.conv2(out)), True)
        out = F.relu(self.bn2(self.conv3(out)), True)
        out = F.relu(self.bn1(self.deconv1(out)), True)
        out = F.relu(self.bn0(self.deconv2(out)), True)
        out = torch.tanh(self.conv4(out))
        return out
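One non-obvious property of this generator is that it returns an image of the same spatial size as its input. For the 300x300 inputs used here, the layer sizes can be traced with the standard Conv2d/ConvTranspose2d output-size formulas (a plain-Python sketch of the shape arithmetic, not PyTorch code):

```python
def conv_out(n, k, s, p):
    # Conv2d output size: floor((n + 2p - k) / s) + 1
    return (n + 2 * p - k) // s + 1

def deconv_out(n, k, s, p):
    # ConvTranspose2d output size: (n - 1)*s - 2p + k
    return (n - 1) * s - 2 * p + k

n = 300
n = conv_out(n, 7, 1, 3)    # conv1:   300 -> 300
n = conv_out(n, 4, 2, 1)    # conv2:   300 -> 150
n = conv_out(n, 4, 2, 1)    # conv3:   150 -> 75
n = deconv_out(n, 4, 2, 1)  # deconv1: 75  -> 150
n = deconv_out(n, 4, 2, 1)  # deconv2: 150 -> 300
n = conv_out(n, 7, 1, 3)    # conv4:   300 -> 300
```

The two stride-2 downsampling convolutions are exactly undone by the two stride-2 transposed convolutions.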
In [49]:
# Create the generator
netG = Generator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.02.
netG.apply(weights_init)

# Print the model
print(netG)
Generator(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
  (conv2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv3): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv4): Conv2d(64, 1, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
  (deconv1): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (deconv2): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (resNet): ResNetBlock(
    (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (bn0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

Discriminator¶

In [50]:
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        # Device
        self.ngpu = ngpu
        # Convolutional Layers
        self.conv1 = nn.Conv2d(nc, ndf, 4, 2, 1, bias=False)
        self.conv2 = nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False)
        self.conv3 = nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False)
        self.conv4 = nn.Conv2d(ndf * 4, ndf * 8, 4, 1, 1, bias=False)
        self.conv5 = nn.Conv2d(ndf * 8, 1, 4, 1, 1, bias=False)
        # Batch Norm Layers
        self.bn1 = nn.BatchNorm2d(ndf * 2)
        self.bn2 = nn.BatchNorm2d(ndf * 4)
        self.bn3 = nn.BatchNorm2d(ndf * 8)
        self.bn4 = nn.BatchNorm2d(ndf * 16)
        # Fully Connected Layers
        self.fc1 = nn.Linear(3844, 1024)
        self.fc2 = nn.Linear(1225, 1)


    def forward(self, x):
        out = F.leaky_relu_(self.conv1(x), 0.2)
        out = F.leaky_relu_(self.bn1(self.conv2(out)), 0.2)
        out = F.leaky_relu_(self.bn2(self.conv3(out)), 0.2)
        out = F.leaky_relu_(self.bn3(self.conv4(out)), 0.2)
        out = self.conv5(out).view(out.shape[0], -1)
        #print(out.shape)
        #out = F.relu(self.fc1(out), 0.2)
        #print(out.shape)
        out = self.fc2(out)
        return torch.sigmoid(out)
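The `in_features=1225` of fc2 comes from the final 35x35 single-channel feature map produced by conv5. Tracing the feature-map sizes for a 300x300 input with the standard Conv2d output-size formula (a plain-Python sketch of the shape arithmetic):

```python
def conv_out(n, k, s, p):
    # Conv2d output size: floor((n + 2p - k) / s) + 1
    return (n + 2 * p - k) // s + 1

n = 300
n = conv_out(n, 4, 2, 1)  # conv1: 300 -> 150
n = conv_out(n, 4, 2, 1)  # conv2: 150 -> 75
n = conv_out(n, 4, 2, 1)  # conv3: 75  -> 37
n = conv_out(n, 4, 1, 1)  # conv4: 37  -> 36
n = conv_out(n, 4, 1, 1)  # conv5: 36  -> 35
flat = n * n              # conv5 has 1 output channel: 35 * 35 = 1225
```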
In [51]:
# Create the Discriminator
netD = Discriminator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.02.
netD.apply(weights_init)

# Print the model
print(netD)
Discriminator(
  (conv1): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv3): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv4): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)
  (conv5): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn4): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=3844, out_features=1024, bias=True)
  (fc2): Linear(in_features=1225, out_features=1, bias=True)
)

Set Training Params¶

In [52]:
# create data loaders
torch.manual_seed(1)

kwargs = {}
if device.type == 'cuda':
    torch.cuda.manual_seed(1)
    kwargs = {'num_workers': 1, 'pin_memory': True}

# Create the dataloader
data_loader = get_train_loader(target_dir, template_dir, batch_size, num_train, num_valid, shuffle, **kwargs)

train_loader = data_loader

# Create batch of latent vectors that we will use to visualize the progression of the generator
test_loader = get_train_loader(target_dir, template_dir, 8, num_train, num_valid, shuffle, **kwargs)
test_batch = []
test = next(iter(test_loader))
test_batch.append(test[0].to(device))     
test_batch.append(test[1].to(device))

criterion = nn.BCELoss()

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

sim_label = 1.0
diff_label = 0.0

# Setup Adam optimizers for both G, D, and S
optimizerD = optim.Adam(netD.parameters(), lr=dlr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=glr, betas=(beta1, 0.999))

GAN Training¶

In [ ]:
# Train

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []

errS = 0
iters = 0
# switch to train mode
netG.train()
netD.train()
# Load checkpoint
if os.path.isfile(model_ckpt_file):
    checkpoint = torch.load(model_ckpt_file)
    netG.load_state_dict(checkpoint['netG_state_dict'])
    netD.load_state_dict(checkpoint['netD_state_dict'])
    optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
    optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
    G_losses = checkpoint['G_losses']
    D_losses = checkpoint['D_losses']
    img_list = checkpoint['img_list']
    test_batch = checkpoint['test_batch']
    iters = checkpoint['iters']
    start_epoch = checkpoint['epoch']

print("\n[*] Train on {} sample pairs".format(num_train))

gLossMvAvg = MovingAvg()

for epoch in range(1, num_epochs+1):
    print('\nEpoch: {}/{}'.format(epoch, num_epochs)) 
    train_batch_time = AverageMeter()
    train_losses = AverageMeter()
    matching_correct = 0
    total_matched = 0
    tic = time.time()
    with tqdm(total=num_train) as pbar:
        for i, (x1, enhanced_x1) in enumerate(train_loader): 
            x1, enhanced_x1 = x1.to(device), enhanced_x1.to(device)

            ############################
            # Update D network
            ###########################
            ## Train with all-Gabor Enhanced batch
            netD.zero_grad()
            # Format batch
            b_size = enhanced_x1.size(0)
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            # Forward pass real batch through D
            output = netD(enhanced_x1)
            output = output.view(-1)
            # Calculate loss on all-real batch
            errD_real = criterion(output, label)
            # Calculate gradients for D in backward pass
            errD_real.backward()
            D_x = output.mean().item()
            
            ## Train with all-GAN_Enhanced batch
            GAN_Enhanced = netG(x1)
            label.fill_(fake_label)
            # Classify all GAN_Enhanced batch with D
            output = netD(GAN_Enhanced.detach()).view(-1)
            # Calculate D's loss on the all-GAN_Enhanced batch
            errD_GAN_Enhanced = criterion(output, label)
            # Calculate the gradients for this batch
            errD_GAN_Enhanced.backward()
            D_G_z1 = output.mean().item()
            # Add the gradients from the all-real and all-GAN_Enhanced batches
            errD = errD_real + errD_GAN_Enhanced
            # Update D
            optimizerD.step()
            
            
            ############################
            # Update G network
            ############################
            netG.zero_grad()
            label.fill_(real_label)  # GAN_Enhanced labels are real for generator cost
            # Since we just updated D, perform another forward pass of all-GAN_Enhanced batch through D
            output = netD(GAN_Enhanced).view(-1)
            # Calculate G's loss based on this output
            errG_quality = criterion(output, label)
            # Calculate gradients for G
            errG_quality.backward()
            D_G_z2 = output.mean().item()

            # Add the gradients
            errG = errG_quality 

            # Update G
            optimizerG.step()
            
            # store batch statistics
            toc = time.time()
            train_batch_time.update(toc-tic)
            tic = time.time()
            pbar.set_description(
                (
                    "loss_D: {:.12f}  loss_G: {:.6f}".format(
                        errD.item(), errG.item()
                    )
                )
            )
            pbar.update(batch_size)
            
            gLossMvAvg.push(errG.item())
            
            # Save Losses for plotting later
            G_losses.append(errG.item())
            D_losses.append(errD.item())
            
            # Check how the generator is doing by saving G's output on a fixed test batch
            if (iters % 50000 == 0) or ((epoch == num_epochs-1) and (i == len(train_loader)-1)):
                with torch.no_grad():
                    fake = netG(test_batch[0]).detach().cpu()
                img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
                #saveCkpt(model_ckpt_file, epoch, netG, netD, optimizerG, optimizerD, G_losses, D_losses, img_list, test_batch, iters)

            iters += 1   
        # Plot the fake images from the last epoch
        #saveCkpt(model_ckpt_file, epoch, netG, netD, optimizerG, optimizerD, G_losses, D_losses, img_list, test_batch, iters)
        plotFigures(netG, test_batch)
# Save results
#saveCkpt(model_ckpt_file, epoch, netG, netD, optimizerG, optimizerD, G_losses, D_losses, img_list, test_batch, iters)
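The `gLossMvAvg` tracker used in the loop above is not defined in this excerpt; a minimal sketch of a windowed moving-average helper with the assumed `push` interface (the class name and `window` parameter are assumptions, not from the notebook):

```python
from collections import deque

class MovingAverage:
    """Windowed moving average with the `push` interface assumed above."""
    def __init__(self, window=100):
        # deque with maxlen automatically drops the oldest value once full
        self.values = deque(maxlen=window)

    def push(self, value):
        self.values.append(value)

    def mean(self):
        return sum(self.values) / len(self.values) if self.values else 0.0
```

It would be constructed before the training loop, e.g. `gLossMvAvg = MovingAverage(window=100)`, and queried with `gLossMvAvg.mean()` to smooth out the per-batch noise in the generator loss.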
[*] Train on 10000 sample pairs

Epoch: 1/30
loss_D: 0.692690134048  loss_G: 2.070850: : 10048it [03:32, 47.33it/s]
Epoch: 2/30
loss_D: 0.176055580378  loss_G: 3.421213: : 10048it [03:17, 50.93it/s]
Epoch: 3/30
loss_D: 0.409008800983  loss_G: 3.451027: : 10048it [03:20, 50.05it/s]
Epoch: 4/30
loss_D: 1.694708466530  loss_G: 3.923312: : 10048it [03:23, 49.40it/s]
Epoch: 5/30
loss_D: 0.138081252575  loss_G: 5.010532: : 10048it [03:06, 53.90it/s]
Epoch: 6/30
loss_D: 0.114352263510  loss_G: 5.659535: : 10048it [02:57, 56.69it/s]
Epoch: 7/30
loss_D: 0.259667217731  loss_G: 3.549457: : 10048it [03:42, 45.25it/s]
Epoch: 8/30
loss_D: 0.084674209356  loss_G: 6.660535: : 10048it [03:37, 46.23it/s]
Epoch: 9/30
loss_D: 0.151727944613  loss_G: 5.857696: : 10048it [03:36, 46.47it/s]
Epoch: 10/30
loss_D: 0.053639583290  loss_G: 6.835809: : 10048it [03:32, 47.33it/s]
Epoch: 11/30
loss_D: 0.168991118670  loss_G: 5.853852: : 10048it [03:25, 48.89it/s]
Epoch: 12/30
loss_D: 0.086226820946  loss_G: 7.945567: : 10048it [03:32, 47.36it/s]
Epoch: 13/30
loss_D: 0.015403938480  loss_G: 9.593180: : 10048it [03:26, 48.55it/s]
Epoch: 14/30
loss_D: 0.249410673976  loss_G: 6.715858: : 10048it [03:33, 47.12it/s]
Epoch: 15/30
loss_D: 0.033704087138  loss_G: 7.684072: : 10048it [03:33, 47.01it/s]
Epoch: 16/30
loss_D: 0.023600436747  loss_G: 9.432259: : 10048it [03:22, 49.64it/s]
Epoch: 17/30
loss_D: 0.516302824020  loss_G: 5.680162: : 10048it [03:35, 46.66it/s]
Epoch: 18/30
loss_D: 0.000679231947  loss_G: 14.174907: : 10048it [03:30, 47.73it/s]
Epoch: 19/30
loss_D: 0.010148795322  loss_G: 9.763084: : 10048it [03:32, 47.31it/s]
Epoch: 20/30
loss_D: 0.066388964653  loss_G: 8.235796: : 10048it [03:27, 48.53it/s]
Epoch: 21/30
loss_D: 0.016905173659  loss_G: 9.714904: : 10048it [03:26, 48.70it/s]
Epoch: 22/30
loss_D: 0.083130896091  loss_G: 8.089863: : 10048it [03:22, 49.70it/s]
Epoch: 23/30
loss_D: 0.010064009577  loss_G: 16.118488: : 10048it [03:19, 50.40it/s]
Epoch: 24/30
loss_D: 0.153829082847  loss_G: 10.549116: : 10048it [03:05, 54.08it/s]
Epoch: 25/30
loss_D: 0.063932746649  loss_G: 12.650004: : 10048it [03:13, 52.00it/s]
Epoch: 26/30
loss_D: 0.000878154999  loss_G: 14.005483: : 10048it [03:06, 53.89it/s]
Epoch: 27/30
loss_D: 0.070283703506  loss_G: 9.386616:   8%|▊         | 832/10000 [00:16<02:51, 53.46it/s]
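The log shows a classic failure signature: the discriminator loss trends toward zero while the generator loss climbs, meaning D overpowers G, which typically precedes the mode collapse described in the introduction. A hypothetical helper (not part of the notebook; the thresholds are assumptions) that flags this condition from the recorded `D_losses` and `G_losses` lists:

```python
def d_overpowers_g(d_losses, g_losses, window=100, d_thresh=0.05, g_thresh=8.0):
    """Return True if, averaged over the last `window` iterations,
    the discriminator loss is near zero while the generator loss is large."""
    if len(d_losses) < window or len(g_losses) < window:
        return False  # not enough history to judge
    d_avg = sum(d_losses[-window:]) / window
    g_avg = sum(g_losses[-window:]) / window
    return d_avg < d_thresh and g_avg > g_thresh
```

Such a check could be run once per epoch to stop training (or save a checkpoint) before the generator degenerates into producing the uniform grey images noted in the introduction.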

Results¶

Training Loss¶

In [54]:
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

Animation of Generator Progression¶

Below is an animation of the generator's progression throughout training. Note that, for each frame, the generator was fed the same batch of sample images as input.

In [1]:
#%%capture
# Imports needed for this cell to run standalone
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

fig = plt.figure(figsize=(16,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())

Side by Side Comparison of Original, Enhanced, and Real Images¶

  • Top: Original (poor-quality input) Images
  • Middle: Enhanced (generated) Images
  • Bottom: Real Images
In [56]:
# Plot the original (poor-quality input) images
plt.figure(figsize=(15,15))
plt.subplot(3,1,1)
plt.axis("off")
plt.title("Original Images")
plt.imshow(np.transpose(vutils.make_grid(test_batch[0].to(device)[:], padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the fake images from the last epoch
plt.subplot(3,1,2)
plt.axis("off")
plt.title("Enhanced Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))

# Plot the real images
plt.subplot(3,1,3)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(test_batch[1].to(device)[:], padding=5, normalize=True).cpu(),(1,2,0)))
plt.show()
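As a quantitative complement to the visual comparison above, one could compute the peak signal-to-noise ratio (PSNR) between each enhanced image and its corresponding real impression. This sketch is not part of the original notebook and assumes the image tensors have been normalized to the range [0, 1]:

```python
import numpy as np

def psnr(enhanced, target, max_val=1.0):
    """Peak signal-to-noise ratio (dB) between two images in [0, max_val]."""
    enhanced = np.asarray(enhanced, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    mse = np.mean((enhanced - target) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * np.log10(max_val ** 2 / mse)
```

Averaging `psnr(fake[i], test_batch[1][i])` over a test batch would give a single number to track across checkpoints, rather than relying on visual inspection alone.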